Optimize MLA generate_mask with scatter update#4261
Draft
JHCuc3m wants to merge 1 commit into
Draft
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
Replace high-complexity broadcasted comparison in `generate_mask`
with a JAX scatter-based update to prevent OOM and reduce overhead.
- Why: The previous implementation relied on a broadcasted comparison between
the full sequence and the selected top-k indices. At scale (large context
lengths), this created a massive logical intermediate tensor, causing immediate
OOM crashes during local CPU testing. On TPU, even though the XLA compiler successfully
fused the operation to avoid physical memory (HBM) OOM, the hardware remained heavily
bottlenecked by the extreme number of element-wise ALU comparisons executed in registers.
- How: Initializes the mask with `DEFAULT_MASK_VALUE` using `jnp.full`, then
uses advanced indexing to scatter-write `0.0` at the selected `topk_indices`.
This fundamentally changes the algorithm to perform direct writes instead of
comparing all elements, reducing complexity and instruction count.
- Verification:
- Added `test_generate_mask_equivalence` to `attention_test.py` to verify
mathematical equivalence (Passed).
- Ran existing unit tests (Passed).
TAG=agy
CONV=f93063bc-d96c-46f1-9562-20d2a5bf3241
c6af97c to
8d32e21
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR optimizes the
generate_maskfunction in the Multi-head Latent Attention (MLA) sparse indexer by replacing a high-complexity broadcasted comparison with a JAX scatter-based update.The Problem: The previous implementation$O(b \cdot t \cdot k \cdot s)$ . At a 128K context window, which theoretically creates a massive virtual intermediate tensor of ~274 Billion elements (~256 GB). On TPU, although the XLA compiler successfully fused the loop to avoid physical HBM allocation (reducing HBM traffic to ~142 MB), the hardware remained heavily bottlenecked by the sheer volume of element-wise ALU comparisons (274 Billion ops) executed in registers, taking ~153 ms per step.
(jnp.arange(s) == topk_indices[..., None]).any(axis=-2)had an algorithmic complexity ofThe Solution: The optimized version initializes the mask to$O(b \cdot t \cdot s)$ and then uses indexing to scatter-write $O(b \cdot t \cdot k)$ , where $k \ll s$ (2,048 vs 131,072).
DEFAULT_MASK_VALUE0.0at the selectedtopk_indicesImplementation Details: Initializes the mask using
jnp.fulland performs the scatter update using broadcasted batch and time indices (batch_indices = jnp.arange(b)[:, None, None],time_indices = jnp.arange(t)[None, :, None]) viamask.at[batch_indices, time_indices, topk_indices].set(0.0).Tests
1. Mathematical Equivalence Test
Added
test_generate_mask_equivalencetotests/unit/attention_test.pyto verify that the new scatter-based implementation produces identical results to the old broadcast-based implementation.atol=1e-5).2. Regression Testing
Ran the existing unit test suite to ensure no regressions in attention or model functionality:
tests/unit/attention_test.py(Passed)tests/unit/model_test.py(Passed)Checklist
zj-scatter-mask-valis currently running on the cluster).